{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0261b70c-32e2-4482-9f75-c0dfbf4f9285",
   "metadata": {},
   "source": [
    "### Imagination continued\n",
    "\n",
    "This notebook contains code for getting samples from a given class with Monte Carlo sampling\n",
    "\n",
    "Tested with tensorflow 2.11.0 and Python 3.10.9."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bec94eeb-27ca-4bc5-af6f-bac6d96b34d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import zipfile\n",
    "import os\n",
    "import psutil\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from PIL import Image, ImageOps\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow.keras.backend as K\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras import Model, Sequential, metrics, optimizers, layers\n",
    "from tensorflow.python.framework.ops import disable_eager_execution\n",
    "from utils import load_tfds_dataset\n",
    "from generative_model import encoder_network_large, decoder_network_large, VAE\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "tf.keras.utils.set_random_seed(123)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c3494d0-27da-4c1f-817d-05de7d11eaf1",
   "metadata": {},
   "source": [
    "#### Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b0c6a3a-b211-49a0-b600-832844bb37a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "K.set_image_data_format('channels_last')\n",
    "\n",
    "latent_dim = 20\n",
    "input_shape = (64, 64, 3)\n",
    "\n",
    "encoder, z_mean, z_log_var = encoder_network_large(input_shape, latent_dim)\n",
    "decoder = decoder_network_large(latent_dim)\n",
    "\n",
    "encoder.load_weights(\"model_weights/shapes3d_encoder.h5\")\n",
    "decoder.load_weights(\"model_weights/shapes3d_decoder.h5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3e949da-3801-44ec-b6c9-a772c4ad44eb",
   "metadata": {},
   "source": [
    "#### Train classifier\n",
    "\n",
    "(Identical method to classifier used to measure decoding accuracy over time.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9e87d4-7f35-4e6c-a41d-d193cb1402f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds, test_ds, train_labels, test_labels = load_tfds_dataset('shapes3d', labels=True, \n",
    "                                                                 key_dict= {'shapes3d': 'label_shape'})\n",
    "train_ds = train_ds / 255\n",
    "test_ds = test_ds / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7617f8b3-7a05-4ee8-8e20-0b40b85e039b",
   "metadata": {},
   "outputs": [],
   "source": [
    "latents = encoder.predict(test_ds)\n",
    "clf = make_pipeline(StandardScaler(), SVC(probability=True))\n",
    "clf.fit(latents[0], test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a40644ed-1acf-4f55-b713-53063ce64788",
   "metadata": {},
   "source": [
    "#### Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d17ad26-7a4d-4cf2-a5ed-626165b06bb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# perform Monte Carlo sampling to generate new samples\n",
    "n_samples = 10000  # number of samples to generate\n",
    "samples = np.random.normal(loc=0, scale=1, size=(n_samples, latent_dim))\n",
    "\n",
    "# predict the labels of the new samples\n",
    "prob = clf.predict_proba(samples) \n",
    "\n",
    "class_idx = 0\n",
    "\n",
    "# get the probabilities for the desired class\n",
    "class_probabilities = prob[:, class_idx]\n",
    "# sort the samples by the probabilities, in descending order \n",
    "sorted_samples = samples[np.argsort(class_probabilities)[::-1]]\n",
    "# decode the high-scoring samples into images\n",
    "decoded_images = decoder.predict(sorted_samples)\n",
    "\n",
    "n = 20 # number of images to display\n",
    "plt.figure(figsize=(40, 4))\n",
    "\n",
    "for i in range(n):\n",
    "    ax = plt.subplot(2, n, i + 1)\n",
    "    plt.imshow(decoded_images[i].reshape(64, 64, 3))\n",
    "    plt.gray()\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
